{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/DeepLabCut/DeepLabCut/blob/main/examples/COLAB/COLAB_transformer_reID.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TGChzLdc-lUJ"
      },
      "source": [
        "# Demo: How to use our Pose Transformer for unsupervised identity tracking of animals\n",
        "![alt text](https://images.squarespace-cdn.com/content/v1/57f6d51c9f74566f55ecf271/1628250004229-KVYD7JJVHYEFDJ32L9VJ/DLClogo2021.jpg?format=1000w)\n",
        "\n",
        "https://github.com/DeepLabCut/DeepLabCut\n",
        "\n",
        "### This notebook illustrates how to use the transformer for a multi-animal DeepLabCut (maDLC) Demo tri-mouse project:\n",
        "- load our mini-demo data that includes a pretrained model and unlabeled video.\n",
        "- analyze a novel video.\n",
        "- use the transformer to do unsupervised ID tracking.\n",
        "- create quality check plots and video.\n",
        "\n",
        "### To create a full maDLC pipeline please see our full docs: https://deeplabcut.github.io/DeepLabCut/README.html\n",
        "- Of interest is a full how-to for maDLC: https://deeplabcut.github.io/DeepLabCut/docs/maDLC_UserGuide.html\n",
        "- a quick guide to maDLC: https://deeplabcut.github.io/DeepLabCut/docs/tutorial.html\n",
        "- a demo COLAB for how to use maDLC on your own data: https://github.com/DeepLabCut/DeepLabCut/blob/main/examples/COLAB/COLAB_maDLC_TrainNetwork_VideoAnalysis.ipynb\n",
        "\n",
        "### To get started, please go to \"Runtime\" ->\"change runtime type\"->select \"Python3\", and then select \"GPU\"\n"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "‼️ **Attention: this demo is for maDLC, which is version 2.2**\n"
      ],
      "metadata": {
        "id": "xOe2hvy85EVP"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NXmLeZBX45Oe"
      },
      "outputs": [],
      "source": [
        "# Install DLC version 2.2-2.3 (pre DLC3):\n",
        "!pip install \"deeplabcut[tf]\""
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import deeplabcut\n",
        "import os"
      ],
      "metadata": {
        "id": "TlhrVFKN8euh"
      },
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Wid0GTGMAEnZ"
      },
      "source": [
        "## Important - Restart the Runtime for the updated packages to be imported!\n",
        "\n",
        "PLEASE, click \"restart runtime\" from the output above before proceeding!\n",
        "\n",
        "No information needs edited in the cells below, you can simply click run on each:\n",
        "\n",
        "### Download our Demo Project from our server:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "PusLdqbqJi60",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "dbe30821-d3a7-443f-de74-6cb0bee49aac"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Downloading demo-me-2021-07-14.zip...\n"
          ]
        }
      ],
      "source": [
        "# Download our demo project:\n",
        "import requests\n",
        "from io import BytesIO\n",
        "from zipfile import ZipFile\n",
        "\n",
        "url_record = \"https://zenodo.org/api/records/7883589\"\n",
        "response = requests.get(url_record)\n",
        "if response.status_code == 200:\n",
        "    file = response.json()[\"files\"][0]\n",
        "    title = file[\"key\"]\n",
        "    print(f\"Downloading {title}...\")\n",
        "    with requests.get(file[\"links\"][\"self\"], stream=True) as r:\n",
        "        with ZipFile(BytesIO(r.content)) as zf:\n",
        "            zf.extractall(path=\"/content\")\n",
        "else:\n",
        "    raise ValueError(f\"The URL {url_record} could not be reached.\")"
      ]
    },
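    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A quick sanity check before moving on (a minimal sketch, not part of DeepLabCut): `missing_project_files` is a hypothetical helper introduced here. It assumes the archive extracted to `/content/demo-me-2021-07-14` and contains `config.yaml` and `videos/videocompressed1.mp4`, the paths this notebook uses for the analysis, and it reports whichever of those files are absent."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "\n",
        "def missing_project_files(project_dir, relative_paths):\n",
        "    \"\"\"Return the relative paths that do not exist as files under project_dir.\"\"\"\n",
        "    return [\n",
        "        rel\n",
        "        for rel in relative_paths\n",
        "        if not os.path.isfile(os.path.join(project_dir, rel))\n",
        "    ]\n",
        "\n",
        "\n",
        "# Assumed layout of the demo project; adjust if your archive differs.\n",
        "expected = [\"config.yaml\", os.path.join(\"videos\", \"videocompressed1.mp4\")]\n",
        "missing = missing_project_files(\"/content/demo-me-2021-07-14\", expected)\n",
        "print(\"Missing files:\", missing if missing else \"none\")"
      ]
    },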
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8iXtySnQB0BE"
      },
      "source": [
        "## Analyze a novel 3 mouse video with our maDLC DLCRNet, pretrained on 3 mice data\n",
        "\n",
        "In one step, since `auto_track=True` you extract detections and association costs, create tracklets, & stitch them. We can use this to compare to the transformer-guided tracking below.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "odYrU3o8BSAr"
      },
      "outputs": [],
      "source": [
        "project_path = \"/content/demo-me-2021-07-14\"\n",
        "config_path = os.path.join(project_path, \"config.yaml\")\n",
        "video = os.path.join(project_path, \"videos\", \"videocompressed1.mp4\")"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "deeplabcut.analyze_videos(config_path,[video],\n",
        "                          shuffle=0, videotype=\"mp4\",\n",
        "                          auto_track=True)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 520
        },
        "id": "U_351Hkv81X-",
        "outputId": "f7c30461-101f-47b6-c04f-15809aa5a4bb"
      },
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Using snapshot-20000 for model /content/demo-me-2021-07-14/dlc-models/iteration-0/demoJul14-trainset95shuffle0\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.11/dist-packages/tensorflow/python/keras/engine/base_layer_v1.py:1694: UserWarning: `layer.apply` is deprecated and will be removed in a future version. Please use `layer.__call__` method instead.\n",
            "  warnings.warn('`layer.apply` is deprecated and '\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Activating extracting of PAFs\n",
            "Starting to analyze %  /content/demo-me-2021-07-14/videos/videocompressed1.mp4\n",
            "Loading  /content/demo-me-2021-07-14/videos/videocompressed1.mp4\n",
            "Duration of video [s]:  77.67 , recorded with  30.0 fps!\n",
            "Overall # of frames:  2330  found with (before cropping) frame dimensions:  640 480\n",
            "Starting to extract posture from the video(s) with batchsize: 8\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|██████████| 2330/2330 [00:39<00:00, 58.83it/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Video Analyzed. Saving results in /content/demo-me-2021-07-14/videos...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.11/dist-packages/deeplabcut/utils/auxfun_multianimal.py:83: UserWarning: default_track_method` is undefined in the config.yaml file and will be set to `ellipse`.\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Using snapshot-20000 for model /content/demo-me-2021-07-14/dlc-models/iteration-0/demoJul14-trainset95shuffle0\n",
            "Processing...  /content/demo-me-2021-07-14/videos/videocompressed1.mp4\n",
            "Analyzing /content/demo-me-2021-07-14/videos/videocompressed1DLC_dlcrnetms5_demoJul14shuffle0_20000.h5\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|██████████| 2330/2330 [00:02<00:00, 1088.72it/s]\n",
            "2330it [00:06, 342.29it/s] \n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "The tracklets were created (i.e., under the hood deeplabcut.convert_detections2tracklets was run). Now you can 'refine_tracklets' in the GUI, or run 'deeplabcut.stitch_tracklets'.\n",
            "Processing...  /content/demo-me-2021-07-14/videos/videocompressed1.mp4\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|██████████| 4/4 [00:00<00:00, 1488.53it/s]\n",
            "/usr/local/lib/python3.11/dist-packages/deeplabcut/refine_training_dataset/stitch.py:934: FutureWarning: Starting with pandas version 3.0 all arguments of to_hdf except for the argument 'path_or_buf' will be keyword-only.\n",
            "  df.to_hdf(output_name, \"tracks\", format=\"table\", mode=\"w\")\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "The videos are analyzed. Time to assemble animals and track 'em... \n",
            " Call 'create_video_with_all_detections' to check multi-animal detection quality before tracking.\n",
            "If the tracking is not satisfactory for some videos, consider expanding the training set. You can use the function 'extract_outlier_frames' to extract a few representative outlier frames.\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'DLC_dlcrnetms5_demoJul14shuffle0_20000'"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            }
          },
          "metadata": {},
          "execution_count": 7
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zmdSLRTOER00"
      },
      "source": [
        "### Next, you compute the local, spatio-temporal grouping and track body part assemblies frame-by-frame:"
      ]
    },
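    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `analyze_videos` log above names the two functions involved in this stage, `deeplabcut.convert_detections2tracklets` and `deeplabcut.stitch_tracklets`. To run them by hand (for example after analyzing with `auto_track=False`), a minimal sketch could look like the cell below. `track_step_by_step` is a hypothetical wrapper written for this notebook, and the keyword arguments simply mirror the ones used elsewhere here; please check the DeepLabCut docs for your installed version before relying on them."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def track_step_by_step(\n",
        "    dlc, config_path, videos, shuffle=0, videotype=\"mp4\", track_method=\"ellipse\"\n",
        "):\n",
        "    \"\"\"Create tracklets, then stitch them, for the given videos.\n",
        "\n",
        "    `dlc` is the imported deeplabcut module, passed in explicitly.\n",
        "    \"\"\"\n",
        "    common = dict(shuffle=shuffle, videotype=videotype, track_method=track_method)\n",
        "    # 1) Turn the per-frame detections into short tracklets.\n",
        "    dlc.convert_detections2tracklets(config_path, videos, **common)\n",
        "    # 2) Stitch the tracklets into continuous tracks.\n",
        "    dlc.stitch_tracklets(config_path, videos, **common)\n",
        "\n",
        "\n",
        "# Uncomment to re-run tracking by hand (needs the detections from analyze_videos):\n",
        "# track_step_by_step(deeplabcut, config_path, [video])"
      ]
    },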
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "F-d6kXqnGeUP"
      },
      "source": [
        "## Create a pretty video output:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "aTRbuUQ1FBO0",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "0d182f64-512d-463d-a997-226c7199b724"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Filtering with median model /content/demo-me-2021-07-14/videos/videocompressed1.mp4\n",
            "Saving filtered csv poses!\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.11/dist-packages/deeplabcut/post_processing/filtering.py:298: FutureWarning: Starting with pandas version 3.0 all arguments of to_hdf except for the argument 'path_or_buf' will be keyword-only.\n",
            "  data.to_hdf(outdataname, \"df_with_missing\", format=\"table\", mode=\"w\")\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Starting to process video: /content/demo-me-2021-07-14/videos/videocompressed1.mp4\n",
            "Loading /content/demo-me-2021-07-14/videos/videocompressed1.mp4 and data.\n",
            "Duration of video [s]: 77.67, recorded with 30.0 fps!\n",
            "Overall # of frames: 2330 with cropped frame dimensions: 640 480\n",
            "Generating frames and creating video.\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.11/dist-packages/deeplabcut/utils/make_labeled_video.py:140: FutureWarning: DataFrame.groupby with axis=1 is deprecated. Do `frame.T.groupby(...)` without axis instead.\n",
            "  Dataframe.groupby(level=\"individuals\", axis=1).size().values // 3\n",
            "100%|██████████| 2330/2330 [00:31<00:00, 73.04it/s]\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[True]"
            ]
          },
          "metadata": {},
          "execution_count": 8
        }
      ],
      "source": [
        "#Filter the predictions to remove small jitter, if desired:\n",
        "deeplabcut.filterpredictions(config_path, [video], shuffle=0, videotype=\"mp4\")\n",
        "deeplabcut.create_labeled_video(\n",
        "    config_path,\n",
        "    [video],\n",
        "    videotype=\"mp4\",\n",
        "    shuffle=0,\n",
        "    color_by=\"individual\",\n",
        "    keypoints_only=False,\n",
        "    draw_skeleton=True,\n",
        "    filtered=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AYNlrgeNUG4U"
      },
      "source": [
        "Now, on the left panel if you click the folder icon, you will see the project folder \"demo-me..\"; click on this and go into \"videos\" and you can find the \"..._id_labeled.mp4\" video, which you can double-click on to download and inspect!"
      ]
    },
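    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "If you would rather locate the output from code than from the file browser, here is a minimal sketch. `find_labeled_videos` is a hypothetical helper, and the `*labeled*.mp4` pattern is an assumption based on the file name mentioned above; adjust it if your files are named differently."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import glob\n",
        "import os\n",
        "\n",
        "\n",
        "def find_labeled_videos(video_dir):\n",
        "    \"\"\"Return sorted paths of .mp4 files in video_dir whose name contains 'labeled'.\"\"\"\n",
        "    return sorted(glob.glob(os.path.join(video_dir, \"*labeled*.mp4\")))\n",
        "\n",
        "\n",
        "# Prints nothing if no labeled video has been created yet.\n",
        "for path in find_labeled_videos(\"/content/demo-me-2021-07-14/videos\"):\n",
        "    print(path)"
      ]
    },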
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n7GWMBJUA9x5"
      },
      "source": [
        "### Create Plots of your data:\n",
        "\n",
 after running">
        "> After running, look in \"videos\", then \"plot-poses\", to check out the trajectories (sometimes you need to click the folder refresh icon to see it). Within the folder, for example, see plotmus1.png to view the bodyparts over time vs. pixel position.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "7w9BDIA7BB_i",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "a163087d-cbcb-4e4d-f461-2e24ed19a80b"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Loading  /content/demo-me-2021-07-14/videos/videocompressed1.mp4 and data.\n",
            "Plots created! Please check the directory \"plot-poses\" within the video directory\n"
          ]
        }
      ],
      "source": [
        "deeplabcut.plot_trajectories(config_path, [video], shuffle=0,videotype=\"mp4\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "l7BJQq7nxHVz"
      },
      "source": [
        "# Transformer for reID\n",
        "\n",
        "while the tracking here is very good without using the transformer, we want to demo the workflow for you!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "5xlO6TVYxQWc",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "a433221f-0390-4028-fe68-be0b90adad48"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Using snapshot-20000 for model /content/demo-me-2021-07-14/dlc-models/iteration-0/demoJul14-trainset95shuffle0\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.11/dist-packages/tensorflow/python/keras/engine/base_layer_v1.py:1694: UserWarning: `layer.apply` is deprecated and will be removed in a future version. Please use `layer.__call__` method instead.\n",
            "  warnings.warn('`layer.apply` is deprecated and '\n",
            "/usr/local/lib/python3.11/dist-packages/tensorflow/python/keras/engine/base_layer_v1.py:1694: UserWarning: `layer.apply` is deprecated and will be removed in a future version. Please use `layer.__call__` method instead.\n",
            "  warnings.warn('`layer.apply` is deprecated and '\n",
            "/usr/local/lib/python3.11/dist-packages/tensorflow/python/keras/engine/base_layer_v1.py:1694: UserWarning: `layer.apply` is deprecated and will be removed in a future version. Please use `layer.__call__` method instead.\n",
            "  warnings.warn('`layer.apply` is deprecated and '\n",
            "/usr/local/lib/python3.11/dist-packages/tensorflow/python/keras/engine/base_layer_v1.py:1694: UserWarning: `layer.apply` is deprecated and will be removed in a future version. Please use `layer.__call__` method instead.\n",
            "  warnings.warn('`layer.apply` is deprecated and '\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Activating extracting of PAFs\n",
            "Starting to analyze %  /content/demo-me-2021-07-14/videos/videocompressed1.mp4\n",
            "Loading  /content/demo-me-2021-07-14/videos/videocompressed1.mp4\n",
            "Duration of video [s]:  77.67 , recorded with  30.0 fps!\n",
            "Overall # of frames:  2330  found with (before cropping) frame dimensions:  640 480\n",
            "Starting to extract posture\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|██████████| 2330/2330 [01:18<00:00, 29.78it/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "If the tracking is not satisfactory for some videos, consider expanding the training set. You can use the function 'extract_outlier_frames' to extract a few representative outlier frames.\n",
            "Epoch 10, train acc: 0.61\n",
            "Epoch 10, test acc 0.45\n",
            "Epoch 20, train acc: 0.74\n",
            "Epoch 20, test acc 0.65\n",
            "Epoch 30, train acc: 0.78\n",
            "Epoch 30, test acc 0.55\n",
            "Epoch 40, train acc: 0.76\n",
            "Epoch 40, test acc 0.50\n",
            "Epoch 50, train acc: 0.85\n",
            "Epoch 50, test acc 0.55\n",
            "Epoch 60, train acc: 0.84\n",
            "Epoch 60, test acc 0.60\n",
            "Epoch 70, train acc: 0.85\n",
            "Epoch 70, test acc 0.55\n",
            "Epoch 80, train acc: 0.79\n",
            "Epoch 80, test acc 0.55\n",
            "Epoch 90, train acc: 0.88\n",
            "Epoch 90, test acc 0.55\n",
            "Epoch 100, train acc: 0.84\n",
            "Epoch 100, test acc 0.55\n",
            "loading params\n",
            "Processing...  /content/demo-me-2021-07-14/videos/videocompressed1.mp4\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|██████████| 4/4 [00:00<00:00, 483.21it/s]\n",
            "/usr/local/lib/python3.11/dist-packages/deeplabcut/refine_training_dataset/stitch.py:934: FutureWarning: Starting with pandas version 3.0 all arguments of to_hdf except for the argument 'path_or_buf' will be keyword-only.\n",
            "  df.to_hdf(output_name, \"tracks\", format=\"table\", mode=\"w\")\n"
          ]
        }
      ],
      "source": [
        "deeplabcut.transformer_reID(\n",
        "    config_path,\n",
        "    [video],\n",
        "    shuffle=0,\n",
        "    videotype=\"mp4\",\n",
        "    track_method=\"ellipse\",\n",
        "    n_triplets=100,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uO_yoqN7xiBT"
      },
      "source": [
        "now we can make another video with the transformer-guided tracking:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "id": "MBMbRFEMxmi4",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "5ca4357a-c8e1-46c6-ecad-141bfce48cc5"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Loading  /content/demo-me-2021-07-14/videos/videocompressed1.mp4 and data.\n",
            "Plots created! Please check the directory \"plot-poses\" within the video directory\n"
          ]
        }
      ],
      "source": [
        "deeplabcut.plot_trajectories(\n",
        "    config_path,\n",
        "    [video],\n",
        "    shuffle=0,\n",
        "    videotype=\"mp4\",\n",
        "    track_method=\"transformer\",\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "id": "vx3e-r1CoXaX",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "46cdbd39-d1f6-4b78-abba-7e979740f2a2"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Starting to process video: /content/demo-me-2021-07-14/videos/videocompressed1.mp4\n",
            "Loading /content/demo-me-2021-07-14/videos/videocompressed1.mp4 and data.\n",
            "Duration of video [s]: 77.67, recorded with 30.0 fps!\n",
            "Overall # of frames: 2330 with cropped frame dimensions: 640 480\n",
            "Generating frames and creating video.\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.11/dist-packages/deeplabcut/utils/make_labeled_video.py:140: FutureWarning: DataFrame.groupby with axis=1 is deprecated. Do `frame.T.groupby(...)` without axis instead.\n",
            "  Dataframe.groupby(level=\"individuals\", axis=1).size().values // 3\n",
            "100%|██████████| 2330/2330 [00:31<00:00, 73.75it/s]\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[True]"
            ]
          },
          "metadata": {},
          "execution_count": 12
        }
      ],
      "source": [
        "deeplabcut.create_labeled_video(\n",
        "    config_path,\n",
        "    [video],\n",
        "    videotype=\"mp4\",\n",
        "    shuffle=0,\n",
        "    color_by=\"individual\",\n",
        "    keypoints_only=False,\n",
        "    draw_skeleton=True,\n",
        "    track_method=\"transformer\"\n",
        ")"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "COLAB_transformer_reID.ipynb",
      "provenance": [],
      "machine_shape": "hm",
      "gpuType": "A100",
      "include_colab_link": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
